import math
from collections import OrderedDict
import torch
import torch.nn as nn

"""
残差结构
利用一个1x1卷积下降通道数，然后利用一个3x3卷积提取特征并且上升通道数
最后接上一个残差边
"""
class BasicBlock(nn.Module):
    """Residual unit: a 1x1 conv, a 3x3 conv, then a skip connection.

    Both convolutions are bias-free, use stride 1, and are each followed by
    ``BatchNorm2d`` and ``LeakyReLU(0.1)``. The block input is added in place
    to the output of the second activation, so that output must be
    shape-compatible with the input.

    :param inplanes: channel count of the incoming feature map
    :param planes: indexable pair; ``planes[0]`` is the channel count produced
        by the 1x1 conv and ``planes[1]`` the channel count produced by the
        3x3 conv
    """

    def __init__(self, inplanes, planes):
        super().__init__()

        # 1x1 conv (no padding): maps inplanes -> planes[0] channels and
        # leaves the spatial size untouched.
        mid_channels = planes[0]
        self.conv1 = nn.Conv2d(
            inplanes, mid_channels,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(mid_channels)
        # LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)
        self.relu1 = nn.LeakyReLU(0.1)

        # 3x3 conv with padding 1: maps planes[0] -> planes[1] channels and
        # also keeps the spatial size.
        out_channels = planes[1]
        self.conv2 = nn.Conv2d(
            mid_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Run both conv/BN/activation stages and add the input back on."""
        hidden = self.relu1(self.bn1(self.conv1(x)))
        result = self.relu2(self.bn2(self.conv2(hidden)))

        # In-place add on the freshly computed branch output; ``x`` itself is
        # not modified.
        result += x
        return result


class DarkNet(nn.Module):
    """Convolutional backbone: a stem followed by five downsampling stages.

    The stem is a stride-1 3x3 conv (3 -> 32 channels) with ``BatchNorm2d``
    and ``LeakyReLU(0.1)``. Every stage starts with a stride-2 3x3 conv that
    halves the spatial size (rounding up) and then stacks ``BasicBlock``
    residual units. ``forward`` returns the outputs of the last three stages.
    """

    def __init__(self, layers):
        """
        :param layers: number of residual units in each of the five stages,
            read at indices 0..4 (e.g. ``[1, 2, 8, 8, 4]``)
        """
        super().__init__()
        # Running channel count: read and then updated by _make_resn_layer
        # every time a stage is built.
        self.inplanes = 32

        # Stem (conv + BN + LeakyReLU). Shape comments here and below are
        # written as H,W,C for a 416x416 RGB input: 416,416,3 -> 416,416,32
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu1 = nn.LeakyReLU(0.1)

        # 416,416,32 -> 208,208,64
        self.layer1 = self._make_resn_layer([32, 64], layers[0])
        # 208,208,64 -> 104,104,128
        self.layer2 = self._make_resn_layer([64, 128], layers[1])
        # 104,104,128 -> 52,52,256
        self.layer3 = self._make_resn_layer([128, 256], layers[2])
        # 52,52,256 -> 26,26,512
        self.layer4 = self._make_resn_layer([256, 512], layers[3])
        # 26,26,512 -> 13,13,1024
        self.layer5 = self._make_resn_layer([512, 1024], layers[4])

        # Output channel count of layer1..layer5, in order.
        self.layers_out_filters = [64, 128, 256, 512, 1024]

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every conv and batch-norm parameter in the network.

        Conv weights are drawn from a zero-mean normal distribution with
        std ``sqrt(2 / n)`` where ``n = kernel_h * kernel_w * out_channels``;
        batch-norm weights are set to 1 and biases to 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                n = kernel_h * kernel_w * module.out_channels
                nn.init.normal_(module.weight, 0, math.sqrt(2. / n))
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_resn_layer(self, planes, blocks):
        """Build one stage: a stride-2 downsampling conv plus residual units.

        :param planes: ``[mid_channels, out_channels]``; ``planes[1]`` is the
            stage's output channel count and the whole pair is forwarded to
            each ``BasicBlock``
        :param blocks: number of ``BasicBlock`` units appended after the
            downsampling conv
        :return: ``nn.Sequential`` whose children are named ``ds_conv``,
            ``ds_bn``, ``ds_relu`` and ``residual_0`` .. ``residual_{blocks-1}``
        """
        stage_channels = planes[1]
        stage = OrderedDict()

        # Downsample: 3x3 conv with stride 2, followed by BN and LeakyReLU.
        stage["ds_conv"] = nn.Conv2d(
            self.inplanes, stage_channels,
            kernel_size=3, stride=2, padding=1, bias=False,
        )
        stage["ds_bn"] = nn.BatchNorm2d(stage_channels)
        stage["ds_relu"] = nn.LeakyReLU(0.1)

        # Everything built after this point sees the stage's output width.
        self.inplanes = stage_channels
        for index in range(blocks):
            stage["residual_{}".format(index)] = BasicBlock(self.inplanes, planes)
        return nn.Sequential(stage)

    def forward(self, x):
        """Return the feature maps produced by layer3, layer4 and layer5.

        :param x: input batch with 3 channels (the stem conv expects 3)
        :return: tuple ``(out3, out4, out5)``
        """
        stem = self.relu1(self.bn1(self.conv1(x)))
        shallow = self.layer2(self.layer1(stem))

        out3 = self.layer3(shallow)  # output of the third stage
        out4 = self.layer4(out3)     # output of the fourth stage
        out5 = self.layer5(out4)     # output of the fifth stage
        return out3, out4, out5


def darknet53(pretrained, **kwargs):
    """Build the Darknet-53 backbone, optionally loading weights from a file.

    :param pretrained: a falsy value to keep the initialisation performed by
        ``DarkNet``, or a ``str`` path to a checkpoint file that
        ``torch.load`` turns into a state dict for the model
    :param kwargs: accepted for call-site compatibility; currently unused
    :return: a ``DarkNet`` with ``[1, 2, 8, 8, 4]`` residual units per stage
    :raises TypeError: if ``pretrained`` is truthy but not a ``str``
        (``TypeError`` is still caught by ``except Exception``)
    """
    # Validate before building the network so a bad argument fails fast
    # instead of after the whole model has been constructed.
    if pretrained and not isinstance(pretrained, str):
        raise TypeError("darknet request a pretrained path. got [{}]".format(pretrained))

    model = DarkNet([1, 2, 8, 8, 4])
    if pretrained:
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from trusted sources.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts; the freshly built model lives on the CPU anyway.
        state_dict = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(state_dict)
    return model
